import sys
import os
import torch
import itertools
import argparse
import pickle as pkl
import random
import torch
import math
import json
import string
import logging
import numpy as np
import re
from tqdm import tqdm
from collections import Counter, defaultdict
from utils.dataset_utils import get_dataset

from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from transformers import GPT2Tokenizer, AutoTokenizer, AutoModelForSequenceClassification
from transformers import RobertaTokenizer, RobertaModel, GPTJForCausalLM
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers import AutoModelForCausalLM
from transformers import pipeline
from sentence_transformers import SentenceTransformer

def extract_words(target):
    """Return the words in *target* that sit directly before ``(``, ``)`` or ``,``.

    A word is a run that starts with an ASCII letter, underscore or apostrophe
    and continues with ASCII letters, digits, underscores or apostrophes.  The
    delimiter itself is only looked at, not consumed, so it is never part of
    the returned word.  Matches are returned in order of appearance and
    duplicates are kept.
    """
    word_before_delimiter = re.compile(r"[a-zA-Z_'][a-zA-Z_0-9']*(?=[(),])")
    return [match.group(0) for match in word_before_delimiter.finditer(target)]

def create_embedding(answers, word_to_idx):
    """Turn each answer into a unit-length word-count vector.

    Args:
        answers: Sequence of answer strings; each one is split into words
            with ``extract_words``.
        word_to_idx: Mapping from word to its column index in the output.
            Words that are not keys of this mapping are ignored.

    Returns:
        A ``torch`` tensor with one row per answer and one column per entry
        of ``word_to_idx``.  Each row holds the word counts divided by the
        row's 2-norm; rows without any known word stay all zeros.
    """
    embedding = torch.zeros((len(answers), len(word_to_idx)))

    for row, answer in enumerate(answers):
        occurrences = Counter(extract_words(answer))
        for token, times_seen in occurrences.items():
            if token not in word_to_idx:
                continue
            embedding[row, word_to_idx[token]] = times_seen

    row_norms = torch.norm(embedding, dim=1, keepdim=True)
    # Swap zero norms for 1 so that empty rows divide cleanly and stay zero.
    safe_norms = torch.where(row_norms != 0, row_norms, torch.ones_like(row_norms))

    return embedding / safe_norms

def _print_indexed(title, items):
    """Print *title*, then each element of *items* on its own line as ``index: element``."""
    print(title)
    for index, item in enumerate(items):
        print(f"{index}: {item}")


def main(args):
    """Build and save "local structure" embeddings for a dataset's answers.

    The vocabulary is every word that ``extract_words`` finds in the train and
    test answers combined.  ``create_embedding`` then turns each answer into a
    vector over that vocabulary, and the two resulting tensors are written
    with ``torch.save`` under ``./data/<dataset>/``.  The answers, the
    vocabulary and the embeddings are printed to stdout along the way.

    Args:
        args: Namespace providing ``model_name`` and ``dataset``.

    Raises:
        KeyError: If ``args.model_name`` does not contain
            ``"local_structure"``, or names a variant other than
            ``"local_structure_1"`` (the only one implemented here).
    """
    if "local_structure" not in args.model_name:
        raise KeyError("No embedding for your input！")
    if args.model_name != "local_structure_1":
        # Fail loudly instead of loading the dataset and then exiting without
        # having written anything.
        raise KeyError(f"Unsupported local structure embedding: {args.model_name}")

    # Only the answers are needed; the first and third return values are unused.
    _, train_answer, _, test_answer = get_dataset(dataset=args.dataset, load_from_local=True)

    _print_indexed('train_set: ', train_answer)
    _print_indexed('test_set: ', test_answer)

    # Vocabulary spans both splits so train and test vectors share columns.
    word_set = set()
    for answer in itertools.chain(train_answer, test_answer):
        word_set.update(extract_words(answer))

    # Sorting makes the word -> column assignment deterministic across runs.
    word_to_idx = {word: idx for idx, word in enumerate(sorted(word_set))}

    for key, value in word_to_idx.items():
        print(f"{key}: {value}")

    train_embedding = create_embedding(train_answer, word_to_idx)
    test_embedding = create_embedding(test_answer, word_to_idx)

    _print_indexed('train_embedding: ', train_embedding)
    _print_indexed('test_embedding: ', test_embedding)

    output_dir = f'./data/{args.dataset}'
    # torch.save does not create missing parent directories.
    os.makedirs(output_dir, exist_ok=True)
    train_path = f'{output_dir}/{args.dataset}_train_{args.model_name}.pt'
    test_path = f'{output_dir}/{args.dataset}_test_{args.model_name}.pt'
    torch.save(train_embedding, train_path)
    torch.save(test_embedding, test_path)

    print(f"Feature dimensions (number of different character combinations): {len(word_to_idx)}")
    print(f"Number of training samples: {len(train_answer)}")
    print(f"Number of test samples: {len(test_answer)}")
    print(f"Saved to {train_path} and {test_path}")

if __name__ == '__main__':
    # Command-line entry point: parse the options and hand them to main().
    cli = argparse.ArgumentParser()

    # (flag, type, default) for every supported option, in help-output order.
    option_specs = (
        ("--seed", int, 0),
        ("--model_name", str, "local_structure_1"),
        ("--dataset", str, 'geo880-standard'),
    )
    for flag, value_type, fallback in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback)

    args = cli.parse_args()

    main(args)